import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from scipy.io import loadmat
from torch.utils.data import Dataset
from class_Angular_Spectrum import Angular_Spectrum
import imshift
import matplotlib.pyplot as plt
import csv
import scipy.io as sio
import os
import time
start_time = time.time()  # wall-clock origin; later used for the 'read time' log column
import numpy as np

# ---- device selection --------------------------------------------------------
# Expose only GPU 0 to this process (only effective if CUDA is not initialised yet).
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# ---- optimisation hyper-parameters -------------------------------------------
lr = 0.001            # Adam learning rate
batch_size_train = 1  # samples per DataLoader batch

class DataSet(Dataset):
    """Dataset backed by the ``'tensor'`` variable of a MATLAB .mat file.

    The stored 3-D array (axes A, B, C) is rearranged to (B, C, A, 1): a
    channels-last image plus a trailing sample axis of length 1, so the dataset
    always contains exactly one item.  Assumes the stored array is 3-D.
    """

    def __init__(self, path, transform=None):
        raw = loadmat(path)['tensor']
        with_sample_axis = np.expand_dims(raw, axis=3)           # (A, B, C, 1)
        self.data = np.transpose(with_sample_axis, (1, 2, 0, 3))  # (B, C, A, 1)
        self.transform = transform

    def __len__(self):
        # Size of the trailing sample axis added in __init__ (always 1).
        return self.data.shape[-1]

    def __getitem__(self, item):
        """Return ``(image, 0)``; the label is a constant placeholder.

        Without a transform the raw (B, C, A) numpy slice is returned; with one,
        the transformed result is cast to ``torch.float32``.
        """
        image = self.data[..., item]
        if not self.transform:
            return image, 0
        image = self.transform(image).to(torch.float32)
        return image, 0

# Input: the image stack stored in image3_8x8.mat, wrapped as a single-item dataset.
transform = transforms.ToTensor()
train_dataset = DataSet(path='image3_8x8.mat', transform=transform)
# num_workers=0: the dataset holds exactly one item, and this script has no
# `if __name__ == "__main__":` guard. With the "spawn" start method (default on
# Windows and macOS) every DataLoader worker re-imports this module and would
# re-run the whole script, which fails. Loading in the main process avoids that
# and costs nothing for a single sample.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size_train,
                          shuffle=True, drop_last=True, num_workers=0)

#UPNN model
from YNet1 import ResU_Net
class pdnet(torch.nn.Module):
    """Physics-driven ptychography network.

    Two ResU_Net branches predict (phase, amplitude) maps for the object and the
    probe from the input stack ``x``. The forward model then does the following:
    - builds complex fields from these maps;
    - shifts the object to each of the K1 x K2 scan offsets (project ``imshift``);
    - multiplies by the masked probe;
    - propagates with the project ``Angular_Spectrum`` module;
    - returns the squared magnitude of the propagated field.

    Args:
        mask: probe support multiplied onto the predicted probe. It is stored as a
            plain attribute, so it must already be on the target device.
        wl, N_pixels, pixel_size, init_distance: forwarded to ``Angular_Spectrum``.
        step: spacing between neighbouring scan offsets. It uses the same length
            unit as ``pixel_size``, because offsets are divided by ``pixel_size``.
        K1, K2: number of scan offsets along each grid axis. The network input is
            expected to have ``K = K1 * K2`` channels (64 for the defaults).
    """

    def __init__(self, mask, wl, N_pixels, pixel_size, init_distance,
                 step=1.2, K1=8, K2=8):
        super(pdnet, self).__init__()
        # Scan-grid size; plain ints, so setting them first does not affect the
        # registration order of the submodules below.
        self.K1 = K1
        self.K2 = K2
        self.K = self.K2 * self.K1

        # Object and probe branches, each taking the K-channel input stack.
        self.YNet_obj = ResU_Net(img_ch=self.K, output_ch=1, filters=[16, 32, 64, 128, 256])
        self.YNet_probe = ResU_Net(img_ch=self.K, output_ch=1, filters=[16, 32, 64, 128, 256])
        # Plain attribute, not a registered buffer: it is excluded from
        # state_dict() and Module.to() does not move it.
        self.mask = mask
        self.ASM = Angular_Spectrum(wl, N_pixels, pixel_size, init_distance)

        self.pixl_size = pixel_size
        self.step = step

        # Offsets symmetric about 0, spaced by `step`.
        shifts_1_range = torch.linspace(-self.step * (self.K1 - 1) / 2, self.step * (self.K1 - 1) / 2, self.K1)
        shifts_2_range = torch.linspace(-self.step * (self.K2 - 1) / 2, self.step * (self.K2 - 1) / 2, self.K2)

        # torch.meshgrid defaults to 'ij' indexing. The explicit `indexing=` keyword
        # (torch >= 1.10) is left out, presumably for older-torch compatibility;
        # newer versions emit a UserWarning here.
        shifts_1, shifts_2 = torch.meshgrid(shifts_1_range, shifts_2_range)
        shifts_1 = shifts_1.t()
        shifts_2 = shifts_2.t()
        # Flatten to (K, 1): one (shift_1, shift_2) pair per scan index k.
        self.shifts_1 = shifts_1.reshape(self.K, 1).to(device)
        self.shifts_2 = shifts_2.reshape(self.K, 1).to(device)

    def forward(self, x):
        """Run both branches and the physical forward model on ``x``.

        Returns ``(output, obj_P, obj_A, pro_P, pro_A)``, where ``output`` is
        ``|ASM(shifted object * probe)|**2``. Only ``PA0[0, 0]`` is shifted, so this
        assumes batch size 1.
        """
        obj_P, obj_A = self.YNet_obj(x)
        pro_P, pro_A = self.YNet_probe(x)
        # Complex fields: amplitude * exp(i * phase); the probe is restricted by the mask.
        PA0 = obj_A * torch.exp(1j * obj_P)
        Probe0 = pro_A * torch.exp(1j * pro_P) * self.mask

        # The same probe is paired with every shifted copy of the object.
        probe1 = Probe0.repeat(1, self.K, 1, 1)
        # Offsets are divided by the pixel size before being passed to imshift.
        shifted = [
            imshift.imshift(PA0[0, 0], self.shifts_1[k] / self.pixl_size,
                            self.shifts_2[k] / self.pixl_size, device)
            for k in range(self.K)
        ]
        PA = torch.stack(shifted, dim=0).unsqueeze(0)

        cub = PA * probe1
        E_asm = self.ASM(cub)
        output = torch.abs(E_asm).pow_(2)
        return output, obj_P, obj_A, pro_P, pro_A


# ---- THz ptychography parameters (lengths in mm) -----------------------------
wl = 116e-3         # wavelength, mm
N_pixels = 512      # passed to Angular_Spectrum
pixel_size = 50e-3  # mm
distance = 20       # propagation distance handed to Angular_Spectrum, mm
K1 = 8
K2 = 8
K = K1 * K2

# ---- probe support mask --------------------------------------------------------
# The 2-D 'probe' array gets two leading singleton axes -> (1, 1, H, W), is cast
# to complex64 and moved to the compute device.
probe_support = loadmat('probe_R=65_double.mat')['probe']
mask = (
    torch.from_numpy(probe_support[np.newaxis, np.newaxis, :, :])
    .to(torch.complex64)
    .to(device)
)


# ---- model, objective, optimiser ---------------------------------------------
model = pdnet(mask, wl, N_pixels, pixel_size, distance).to(device)
criterion = torch.nn.MSELoss().to(device)  # compares model output with the noisy input stack
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

epochs = 5000         # number of training iterations
train_loss_hist = []  # loss value of every iteration

# ---- output directories ------------------------------------------------------
dir_save = './results/UPNN/'
dir_save2 = dir_save + '100cycle/'  # receives the CSV log, figures, .mat snapshots and checkpoint
for _out_dir in (dir_save, dir_save2):
    os.makedirs(_out_dir, exist_ok=True)

# ---- logging bookkeeping -----------------------------------------------------
header = ['no', 'cycle time', 'total time', 'read time', 'loss']  # CSV column names
data_read_time = time.time() - start_time  # seconds from script start until now

from tqdm import tqdm
file_time = 0  # running sum of per-iteration training time

# Fetch the training sample. The loop leaves the last batch in `images0` and its
# index in `step`, which the training log prints.
images0 = None
for step, (images0, _) in enumerate(train_loader):
    images0 = images0.to(device)
if images0 is None:
    raise RuntimeError("train_loader produced no batches (dataset smaller than batch_size with drop_last=True?)")

# Model complexity. thop.profile returns multiply-accumulate operations (MACs),
# not FLOPs (roughly 2 FLOPs per MAC). It only counts module types it has rules
# for, so custom ops (complex exp, imshift, angular-spectrum propagation) are
# likely not included.
from thop import profile
macs, params = profile(model, inputs=(images0,))
print(f"MACs: {macs:,}")
def count_trainable_params(module):
    """Return the total number of elements in the parameters of `module` that require gradients."""
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
# Report how many parameter elements of the model require gradients.
print(f"trainable parameters: {count_trainable_params(model):,}")

def _append_csv_row(path, header_row, row):
    """Append `row` to the CSV file at `path`.

    If the file does not exist yet, it is created and `header_row` is written
    first. The file is closed even if writing raises.
    """
    is_new = not os.path.isfile(path)
    with open(path, 'w' if is_new else 'a', newline='') as fh:
        writer = csv.writer(fh)
        if is_new:
            writer.writerow(header_row)
        writer.writerow(row)


# Index tuples for the first batch/channel: a centre crop and the full map.
# NOTE(review): 129:384 spans 255 rows/cols. For N_pixels=512 the exact central
# 256x256 window would be 128:384, so this looks like a MATLAB 1-based range
# carried over. Kept as-is; confirm.
_CROP = (0, 0, slice(129, 384), slice(129, 384))
_FULL = (0, 0, slice(None), slice(None))


def _save_snapshot_figure(epoch, obj_A, obj_P, pro_A, pro_P, filename):
    """Plot object/probe amplitude and phase (cropped and full) on a 2x4 grid and save it to `filename`.

    A fresh figure is created and closed on every call. Otherwise, with
    non-interactive backends (where plt.show() does not close the figure),
    images and colorbars pile up on one figure across snapshots.
    """
    cycle_title = 'cycle' + str(epoch).zfill(6)
    # (tensor, index, colormap, title or None) for subplot positions 1..8.
    panels = [
        (obj_A, _CROP, 'gray', 'obj (amp)'),
        (obj_P, _CROP, 'inferno', 'obj (pha)'),
        (obj_A, _FULL, 'gray', 'obj (amp)'),
        (obj_P, _FULL, 'inferno', 'obj (pha)'),
        (pro_A, _CROP, 'gray', cycle_title),
        (pro_P, _CROP, 'inferno', None),
        (pro_A, _FULL, 'gray', cycle_title),
        (pro_P, _FULL, 'inferno', None),
    ]
    fig = plt.figure()
    for position, (tensor, index, cmap, title) in enumerate(panels, start=1):
        plt.subplot(2, 4, position)
        plt.imshow(tensor[index].detach().cpu().numpy(), cmap=cmap)
        if title is not None:
            plt.title(title)
        plt.colorbar(fraction=0.05, pad=0.05)
        plt.axis('off')
    plt.savefig(filename)
    plt.show()
    plt.close(fig)


path_csv = dir_save2 + 'Ynet' + '_sample' + '.csv'  # loop-invariant log path

for epoch in tqdm(range(1, epochs + 1)):  # 1-based iteration counter
    iter_start = time.time()
    model.train()

    # Fresh Gaussian noise (std 0.05) added to the fixed input every iteration;
    # the loss compares the model output against this noisy input.
    noise = torch.randn_like(images0) * 0.05
    images = images0 + noise
    out_img, obj_P, obj_A, pro_P, pro_A = model(images)
    loss = criterion(out_img, images)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # CUDA kernels run asynchronously; wait for them so iter_time includes this
    # iteration's GPU work instead of only the kernel-launch time.
    if images0.is_cuda:
        torch.cuda.synchronize()
    iter_time = time.time() - iter_start
    file_time += iter_time

    ep_loss = loss.item()
    values = [epoch, iter_time, file_time, data_read_time, ep_loss]
    _append_csv_row(path_csv, header, values)

    train_loss_hist.append(ep_loss)
    print('training:', epoch, '|', step, '|', train_loss_hist[-1], 'time:', file_time)

    if epoch in (1, 10, 50) or epoch % 100 == 0:
        prefix = dir_save2 + 'Ynet' + '_cycle' + str(epoch).zfill(6)
        _save_snapshot_figure(epoch, obj_A, obj_P, pro_A, pro_P, prefix + '_results.png')
        sio.savemat(prefix + '_results.mat', {
            'output_obj_A': obj_A[_FULL].detach().cpu().numpy(),
            'output_obj_P': obj_P[_FULL].detach().cpu().numpy(),
            'output_pro_A': pro_A[_FULL].detach().cpu().numpy(),
            'output_pro_P': pro_P[_FULL].detach().cpu().numpy(),
        })
        torch.save(model.state_dict(), dir_save2 + 'model.pth')





